#!/usr/bin/env python
# ===============================================================
# knowledge_distill_agnews.py
# ===============================================================
"""
Knowledge-distillation demo: AG News  ➜  Quantum VQC  ➜  Smaller VQC
Author: <you>
"""

# ---------------------------------------------------------------
# 0. Imports
# ---------------------------------------------------------------
import warnings
from typing import Optional, Tuple

import numpy as np
from datasets import load_dataset
from qiskit import QuantumCircuit
from qiskit.algorithms.optimizers import COBYLA
from qiskit.circuit import ParameterVector
from qiskit.circuit.library import EfficientSU2
from qiskit.utils import QuantumInstance
from qiskit_aer import AerSimulator
from qiskit_machine_learning.algorithms import VQC
from qiskit_machine_learning.neural_networks import SamplerQNN
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split

# Seed shared by the simulator, the transpiler and the train/test split.
SEED: int = 123
# TF-IDF vocabulary size; also the default teacher width (one qubit per feature).
VECTOR_DIM: int = 1024
# Width of the lighter student circuit.
STUDENT_QUBITS: int = 6
# Fraction of the pooled AG News corpus used for training.
TRAIN_RATIO: float = 0.8

# ---------------------------------------------------------------
# 1. Helper: build a VQC of arbitrary size
# ---------------------------------------------------------------
def build_vqc(num_qubits: int, num_classes: int, reps: int = 2) -> VQC:
    """
    Return an untrained VQC: RY angle encoding followed by an EfficientSU2
    ansatz (full entanglement) with ``reps`` repetitions.

    The feature map has exactly one input parameter per qubit. Data passed
    to the returned model's ``fit`` / ``predict`` / ``score`` must therefore
    have exactly ``num_qubits`` columns. Slice wider feature matrices before
    calling it.

    Args:
        num_qubits: circuit width, which is also the number of input features.
        num_classes: number of target classes. ``VQC`` infers the class
            count from the labels given to ``fit``, so this is only
            validated here.
        reps: EfficientSU2 repetitions (ansatz depth).

    Raises:
        ValueError: if ``num_qubits < 1`` or ``num_classes < 2``.
    """
    if num_qubits < 1:
        raise ValueError(f"num_qubits must be >= 1, got {num_qubits}")
    if num_classes < 2:
        raise ValueError(f"num_classes must be >= 2, got {num_classes}")

    # --- Feature map: RY(x_i) on qubit i ---
    # VQC expects a parameterised QuantumCircuit whose parameters it binds
    # to each input row, so the angles are symbolic Parameters rather than
    # values read from a Python callable.
    inputs = ParameterVector("x", num_qubits)
    feature_map = QuantumCircuit(num_qubits)
    for qubit, theta in enumerate(inputs):
        feature_map.ry(theta, qubit)

    # --- Ansatz (trainable weights) ---
    ansatz = EfficientSU2(num_qubits, reps=reps, entanglement="full")

    # --- Execution backend, seeded for reproducibility ---
    backend = AerSimulator(seed_simulator=SEED)
    qi      = QuantumInstance(backend=backend,
                              seed_simulator=SEED,
                              seed_transpiler=SEED)

    # VQC composes feature_map + ansatz and builds its own QNN internally,
    # so no separate QNN object is constructed here.
    return VQC(feature_map=feature_map,
               ansatz=ansatz,
               optimizer=COBYLA(),  # an Optimizer instance, not a name string
               quantum_instance=qi)

# ---------------------------------------------------------------
# 2. “Server” that does everything
# ---------------------------------------------------------------
class DistillationServer:
    """
    Owns the AG News data plus a teacher and a student VQC. Runs
    train-teacher → distil-student → evaluate.

    Each model angle-encodes one TF-IDF feature per qubit, so it is fed only
    the leading ``num_qubits`` columns of the feature matrix.
    ``_prepare_data`` orders columns by total TF-IDF weight, so those
    leading columns are the most heavily weighted terms.

    Args:
        vector_dim: TF-IDF vocabulary size (``max_features``).
        student_qubits: width of the student circuit.
        teacher_qubits: width of the teacher circuit. ``None`` (the default)
            uses ``vector_dim``, i.e. one qubit per feature. A dense
            statevector beyond ~30 qubits does not fit in memory, so pass a
            small value for runnable experiments.

    Raises:
        ValueError: if a qubit count is outside ``[1, vector_dim]`` or exceeds
            the number of features TF-IDF actually produced.
    """

    # Width above which a warning is emitted: a dense statevector of
    # 2**30 complex128 amplitudes is already 16 GiB.
    _SIM_QUBIT_WARN = 30

    def __init__(self,
                 vector_dim: int = VECTOR_DIM,
                 student_qubits: int = STUDENT_QUBITS,
                 teacher_qubits: Optional[int] = None):
        self.vector_dim     = vector_dim
        self.student_qubits = student_qubits
        self.teacher_qubits = (vector_dim if teacher_qubits is None
                               else teacher_qubits)

        # Validate before the (slow) dataset download and vectorisation.
        for role, n in (("teacher", self.teacher_qubits),
                        ("student", self.student_qubits)):
            if not 1 <= n <= vector_dim:
                raise ValueError(
                    f"{role}_qubits must be in [1, {vector_dim}] "
                    f"(one TF-IDF feature per qubit), got {n}")
        if self.teacher_qubits > self._SIM_QUBIT_WARN:
            warnings.warn(
                f"teacher uses {self.teacher_qubits} qubits; a statevector "
                f"simulation needs 2**{self.teacher_qubits} amplitudes and is "
                f"unlikely to run. Pass teacher_qubits <= "
                f"{self._SIM_QUBIT_WARN} for a runnable experiment.",
                RuntimeWarning, stacklevel=2)

        self._prepare_data()

        n_features = self.train_X.shape[1]
        if max(self.teacher_qubits, self.student_qubits) > n_features:
            raise ValueError(
                f"TF-IDF produced only {n_features} features, fewer than the "
                f"requested qubit count")
        num_classes = int(np.unique(self.train_y).size)

        # Teacher is “large”: deeper ansatz, by default one qubit per feature
        self.teacher = build_vqc(num_qubits=self.teacher_qubits,
                                 num_classes=num_classes,
                                 reps=2)

        # Student is “small”: fewer qubits, shallower ansatz
        self.student = build_vqc(num_qubits=self.student_qubits,
                                 num_classes=num_classes,
                                 reps=1)

        # evaluate() reports nan for models that have not been trained yet.
        self._teacher_fitted = False
        self._student_fitted = False
        # Temperature-scaled teacher probabilities from the last soft
        # distillation run (None until distill_student(use_soft=True)).
        self.soft_targets: Optional[np.ndarray] = None

    # ----------------------- data -------------------------------
    def _prepare_data(self):
        """
        Load AG News, build TF-IDF features scaled to [0, π], split train/test.

        The official train and test splits are pooled. They are then re-split
        with ``TRAIN_RATIO``, stratified by label and seeded with ``SEED``.
        Sets ``train_X``, ``test_X``, ``train_y`` and ``test_y``.
        """
        ds = load_dataset("ag_news")
        # list() so concatenation works whether a column comes back as a
        # plain list or as a lazy column object.
        texts   = list(ds["train"]["text"]) + list(ds["test"]["text"])
        labels  = list(ds["train"]["label"]) + list(ds["test"]["label"])

        # Ask for float32 directly. This avoids densifying a float64 matrix
        # (n_samples x vector_dim) and then copying it.
        vect = TfidfVectorizer(max_features=self.vector_dim,
                               stop_words="english",
                               dtype=np.float32)
        X = vect.fit_transform(texts)

        # TfidfVectorizer orders columns alphabetically by term. Models read
        # only the leading num_qubits columns, so reorder by total TF-IDF
        # weight (descending; stable for ties). Those columns then become the
        # heaviest-weighted terms instead of the alphabetically first ones.
        col_weight = np.asarray(X.sum(axis=0)).ravel()
        order = np.argsort(-col_weight, kind="stable")
        X = X[:, order].toarray()

        # Scale to [0, π] for RY angle encoding (epsilon guards an all-zero X).
        X *= np.pi / (X.max() + 1e-12)
        y = np.array(labels, dtype=int)

        (self.train_X,
         self.test_X,
         self.train_y,
         self.test_y) = train_test_split(X, y,
                                         train_size=TRAIN_RATIO,
                                         stratify=y,
                                         random_state=SEED)

    @staticmethod
    def _features(X: np.ndarray, num_qubits: int) -> np.ndarray:
        """Return the leading ``num_qubits`` columns of ``X`` (one angle per qubit)."""
        return X[:, :num_qubits]

    # ------------------- training routines ----------------------
    def train_teacher(self):
        """Fit the teacher on the training split (leading ``teacher_qubits`` features)."""
        print("Training TEACHER VQC …")
        self.teacher.fit(self._features(self.train_X, self.teacher_qubits),
                         self.train_y)
        self._teacher_fitted = True

    def distill_student(self, use_soft=False, T: float = 2.0):
        """
        Train the student to mimic the teacher on the training split.

        *Hard* KD (default): the student is fit on the teacher's predicted
        labels.

        *Soft* KD: the teacher's class probabilities come from one batched
        forward pass. They are temperature-scaled with ``T`` and stored in
        ``self.soft_targets``. ``VQC.fit`` trains on class labels, not
        probability targets, so the student is still fit on the arg-max.
        Temperature scaling never changes that arg-max. The soft targets
        are kept for inspection or for a custom KD loss.

        Raises:
            ValueError: if ``use_soft`` is set and ``T`` is not positive.
        """
        X_teacher = self._features(self.train_X, self.teacher_qubits)

        # ---------- generate teacher outputs ----------
        if use_soft:
            if T <= 0:
                raise ValueError(f"temperature T must be > 0, got {T}")
            # Forward pass of the trained teacher network with its fitted
            # weights. NOTE(review): expected shape (n_samples, n_classes) —
            # confirm for the installed qiskit-machine-learning version.
            probs = np.asarray(
                self.teacher.neural_network.forward(X_teacher,
                                                    self.teacher.weights),
                dtype=np.float64)
            # softmax(log p / T), shifted by the row max for numerical stability.
            logits = np.log(probs + 1e-12) / T
            logits -= logits.max(axis=1, keepdims=True)
            soft_targets = np.exp(logits)
            soft_targets /= soft_targets.sum(axis=1, keepdims=True)
            self.soft_targets = soft_targets.astype(np.float32)
            # NOTE(review): assumes output column k corresponds to label k
            # (labels here are small ints); verify against the VQC's
            # internal label encoding.
            labels_for_student = np.argmax(soft_targets, axis=1)
        else:
            # hard labels = teacher predictions
            labels_for_student = self.teacher.predict(X_teacher)

        # ---------- train student ----------
        print("Distilling STUDENT VQC …")
        self.student.fit(self._features(self.train_X, self.student_qubits),
                         labels_for_student)
        self._student_fitted = True

    # ---------------------- metrics -----------------------------
    def evaluate(self) -> Tuple[float, float]:
        """
        Score both models on the held-out split against the true labels.

        Returns:
            ``(teacher_accuracy, student_accuracy)``. A model that has not
            been trained yet is not scored and reports ``nan``. This lets
            the teacher be evaluated before distillation has happened.
        """
        def _score(model, fitted: bool, num_qubits: int) -> float:
            if not fitted:
                return float("nan")
            return model.score(self._features(self.test_X, num_qubits),
                               self.test_y)

        teach_acc = _score(self.teacher, self._teacher_fitted,
                           self.teacher_qubits)
        stud_acc  = _score(self.student, self._student_fitted,
                           self.student_qubits)
        return teach_acc, stud_acc

# ---------------------------------------------------------------
# 3. Entry-point
# ---------------------------------------------------------------
def main():
    """Train the teacher, distil it into the student with hard labels, and print accuracies."""
    server = DistillationServer()

    # Teacher on its own first.
    server.train_teacher()
    teacher_only_acc = server.evaluate()[0]
    print(f"\n==> Teacher-only accuracy: {teacher_only_acc:.3f}")

    # Hard-label knowledge distillation, then re-score both models.
    server.distill_student(use_soft=False)
    teacher_acc, student_acc = server.evaluate()
    print("\n==> After distillation:")
    print(f"    Teacher accuracy : {teacher_acc:.3f}")
    print(f"    Student accuracy : {student_acc:.3f}")


if __name__ == "__main__":
    main()
